/*
 * SPDX-FileCopyrightText: 2025 kvcache-ai
 * SPDX-FileCopyrightText: 2025 Qingcheng.AI
 *
 * SPDX-License-Identifier: Apache-2.0
 */

/**
 * This file has adaption of open-source code from the following sources:
 * - https://github.com/kvcache-ai/ktransformers, licensed under Apache 2.0.
 */

#ifndef CPUINFER_OPERATOR_MOE_H
#define CPUINFER_OPERATOR_MOE_H

#include <cmath>
#include <cstdio>
#include <functional>
#include <mutex>
#include <vector>

#include "conversion.h"
#include "cpuinfer.h"
#include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
#include "shared_mem_buffer.h"
/**
 * Construction parameters for the CPU mixture-of-experts operator (MOE).
 *
 * Every field carries a default member initializer, so a default-constructed
 * MOEConfig holds zeros / null pointers / the zero-valued ggml_type enumerator
 * instead of indeterminate values.
 */
struct MOEConfig {
    int expert_num = 0;        // number of experts held in the weight tensors
    int routed_expert_num = 0; // NOTE(review): presumably the max number of
                               // experts routed per token (sizes MOE's
                               // per-route scratch buffers) -- confirm
    int hidden_size = 0;       // model hidden dimension
    int intermediate_size = 0; // per-expert FFN intermediate dimension
    int stride = 0;            // NOTE(review): meaning not visible in this
                               // header -- confirm in the implementation
    int group_min_len = 0;     // NOTE(review): presumably a token-count
                               // threshold for the batched path -- confirm
    int group_max_len = 0;     // max tokens per group; sizes MOE's m_* buffers

    // Expert weight tensors. NOTE(review): ownership is not visible here;
    // they are assumed to be caller-owned and to outlive the MOE instance.
    void *gate_proj = nullptr;
    void *up_proj = nullptr;
    void *down_proj = nullptr;

    // Element types of the three weight tensors and of the activations.
    // `{}` value-initializes each to the zero-valued enumerator.
    ggml_type gate_type{};
    ggml_type up_type{};
    ggml_type down_type{};
    ggml_type hidden_type{};

    MOEConfig() = default;

    MOEConfig(int expert_num, int routed_expert_num, int hidden_size,
              int intermediate_size, int stride, int group_min_len,
              int group_max_len, void *gate_proj, void *up_proj,
              void *down_proj, ggml_type gate_type, ggml_type up_type,
              ggml_type down_type, ggml_type hidden_type)
        : expert_num(expert_num), routed_expert_num(routed_expert_num),
          hidden_size(hidden_size), intermediate_size(intermediate_size),
          stride(stride), group_min_len(group_min_len),
          group_max_len(group_max_len), gate_proj(gate_proj), up_proj(up_proj),
          down_proj(down_proj), gate_type(gate_type), up_type(up_type),
          down_type(down_type), hidden_type(hidden_type) {}
};

/**
 * Mixture-of-experts (MoE) feed-forward operator executed on the CPU.
 *
 * gate_proj_ / up_proj_ / down_proj_ reference the expert weight tensors
 * described by MOEConfig. Every other pointer member is a scratch buffer whose
 * extent is noted beside it. NOTE(review): the s_* buffers appear to serve
 * forward_one and the m_* buffers forward_many (inferred from naming) --
 * confirm against the implementation.
 *
 * MOE is non-copyable. The class declares a destructor and holds raw pointers
 * into scratch storage. An implicitly generated copy would alias every one of
 * those pointers, and ~MOE would then run on both objects.
 */
class MOE {
  public:
    /// Builds the operator from @p config.
    MOE(MOEConfig);
    ~MOE();

    // Copying is disabled (see class comment). The user-declared destructor
    // already suppresses the implicit move operations.
    MOE(const MOE &) = delete;
    MOE &operator=(const MOE &) = delete;

    /// NOTE(review): presumably runs a dummy pass to fault in weights and
    /// buffers -- confirm.
    /// @param cpu_infer executor the work is submitted to (assumed).
    void warm_up(CPUInfer *cpu_infer);

    /// Single-token forward pass.
    /// @param k          number of experts selected for the token (assumed).
    /// @param expert_ids k expert indices (assumed length k).
    /// @param weights    k routing weights matching expert_ids (assumed).
    /// @param input      activation row; element type presumably hidden_type.
    /// @param output     destination row; element type presumably hidden_type.
    /// @param cpu_infer  executor the work is submitted to (assumed).
    void forward_one(int k, const uint64_t *expert_ids, const float *weights,
                     const void *input, void *output, CPUInfer *cpu_infer);

    /// Multi-token forward pass; @p qlen is the number of token rows
    /// (assumed; expert_ids/weights presumably hold qlen * k entries).
    /// Remaining parameters as in forward_one.
    void forward_many(int qlen, int k, const uint64_t *expert_ids,
                      const float *weights, const void *input, void *output,
                      CPUInfer *cpu_infer);

    /// Entry point for @p qlen tokens. NOTE(review): presumably dispatches to
    /// forward_one / forward_many -- confirm. Parameters as in forward_many.
    void forward(int qlen, int k, const uint64_t *expert_ids,
                 const float *weights, const void *input, void *output,
                 CPUInfer *cpu_infer);

  private:
    MOEConfig config_;

    // Expert weights. Element counts are given below. Byte size is
    // count / ggml_blck_size(type) * ggml_type_size(type) for the tensor's
    // ggml type (gate_type / up_type / down_type).
    void *gate_proj_; // [expert_num * intermediate_size * hidden_size]
    void *up_proj_;   // [expert_num * intermediate_size * hidden_size]
    void *down_proj_; // [expert_num * hidden_size * intermediate_size]

    // In the byte sizes below, vdt(T) abbreviates
    // ggml_internal_get_type_traits(T).vec_dot_type.

    float *s_input_fp32_; // [hidden_size]
    // bytes: hidden_size * ggml_type_size(vdt(gate_type))
    //        / ggml_blck_size(vdt(gate_type))
    uint8_t *s_gate_input_;
    // bytes: hidden_size * ggml_type_size(vdt(up_type))
    //        / ggml_blck_size(vdt(up_type))
    uint8_t *s_up_input_;
    std::vector<float *> s_gate_output_; // [routed_expert_num, intermediate_size]
    std::vector<float *> s_up_output_;   // [routed_expert_num, intermediate_size]
    std::vector<float *>
        s_intermediate_fp32_; // [routed_expert_num, intermediate_size]
    // [routed_expert_num] buffers, each intermediate_size *
    // ggml_type_size(vdt(down_type)) / ggml_blck_size(vdt(down_type)) bytes
    std::vector<uint8_t *> s_down_input_;
    std::vector<float *> s_down_output_; // [routed_expert_num, hidden_size]
    float *s_output_fp32_;               // [hidden_size]

    std::vector<float *> m_input_fp32_; // [group_max_len, hidden_size]
    // [group_max_len] buffers, each hidden_size *
    // ggml_type_size(vdt(gate_type)) / ggml_blck_size(vdt(gate_type)) bytes
    std::vector<uint8_t *> m_gate_input_;
    // [group_max_len] buffers, each hidden_size *
    // ggml_type_size(vdt(up_type)) / ggml_blck_size(vdt(up_type)) bytes
    std::vector<uint8_t *> m_up_input_;
    // bytes: routed_expert_num * group_max_len * hidden_size *
    //        ggml_type_size(vdt(gate_type)) / ggml_blck_size(vdt(gate_type))
    uint8_t *m_local_gate_input_;
    // bytes: routed_expert_num * group_max_len * hidden_size *
    //        ggml_type_size(vdt(up_type)) / ggml_blck_size(vdt(up_type))
    uint8_t *m_local_up_input_;
    // [routed_expert_num * group_max_len * intermediate_size]
    float *m_local_gate_output_;
    // [routed_expert_num * group_max_len * intermediate_size]
    float *m_local_up_output_;
    // [routed_expert_num * group_max_len * intermediate_size]
    float *m_local_intermediate_fp32_;
    // bytes: routed_expert_num * group_max_len * intermediate_size *
    //        ggml_type_size(vdt(down_type)) / ggml_blck_size(vdt(down_type))
    uint8_t *m_local_down_input_;
    // [routed_expert_num * group_max_len * hidden_size]
    float *m_local_down_output_;
    std::vector<float *> m_output_fp32_; // [group_max_len, hidden_size]

    std::vector<std::vector<int>>
        m_local_pos_;              // [group_max_len, routed_expert_num]
    std::vector<int> m_local_num_; // [expert_num]
    // Per-expert cursors into the m_local_* buffers above (assumed).
    std::vector<uint8_t *> m_local_gate_input_ptr_;      // [expert_num]
    std::vector<uint8_t *> m_local_up_input_ptr_;        // [expert_num]
    std::vector<float *> m_local_gate_output_ptr_;       // [expert_num]
    std::vector<float *> m_local_up_output_ptr_;         // [expert_num]
    std::vector<float *> m_local_intermediate_fp32_ptr_; // [expert_num]
    std::vector<uint8_t *> m_local_down_input_ptr_;      // [expert_num]
    std::vector<float *> m_local_down_output_ptr_;       // [expert_num]
};

#endif
